#!/usr/bin/env python3
import ithildin as ith
import numpy as np
import sys
from typing import List
from scipy.interpolate import RectBivariateSpline
from pydiffmap.diffusion_map import DiffusionMap
from pydiffmap import visualization as diff_visualization
import argparse
import pickle
import gc

# ---------------------------------------------------------------------------
# Command-line interface and simulation-data loading.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Compute a diffusion map.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('stem', help='stem for the ithildin simulation data')
parser.add_argument('diffmap', help='file name to save the diffusion map as')
# FIX: typo "neirest" -> "nearest" in the help text.
parser.add_argument('-k', metavar='k', type=int, help='number of nearest neighbours to use', default=64)
# FIX: metavar was 'k' (copy-paste from -k above); label it as the epsilon value.
parser.add_argument('--eps', metavar='eps', type=float, help='kernel bandwidth epsilon for the diffusion map', default=1)
parser.add_argument('--points', type=int, help='number of sampled points to use', default=2048)
parser.add_argument('--normalize', action=argparse.BooleanOptionalAction, help='normalize data before constructing diffusion maps', default=True)
parser.add_argument('--space-clip-percentage', type=int, help='percentage to ignore in each not-small space dimension (for eliminating border effects)', default=25)
parser.add_argument('--small-treshold', type=int, help='number of grid points in a space dimension from which it is considered non-small', default=20)
parser.add_argument('--start-time', type=int, help='time to start at (percentage-wise)', default=10)
parser.add_argument('--end-time', type=int, help='time to end at (percentage-wise)', default=100)
args = parser.parse_args()

print("Loading simulation data from %s ..." % (args.stem,))
data = ith.SimData.from_stem(args.stem)

def apply_percentage(p, x):
    """Return p percent of x, rounded down.

    p is a percentage in [0, 100] and x a non-negative quantity; both
    preconditions are asserted. The floor rounding is expected to have
    only a negligible effect on the result.
    """
    assert 0 <= p <= 100
    assert 0 <= x
    return (p * x) // 100

def start_percentage(dimension):
    """Percentage at which the kept region begins along the given axis.

    Axis 0 is time, clipped at --start-time; every other axis is a space
    dimension, which gives up half of --space-clip-percentage at its front.
    """
    if dimension != 0:
        # space dimension
        return args.space_clip_percentage // 2
    # time dimension
    return args.start_time

def end_percentage(dimension):
    """Percentage at which the kept region ends along the given axis.

    Axis 0 is time, clipped at --end-time. A "small" space dimension
    (fewer grid points than --small-treshold) keeps its full extent,
    since clipping it would discard a large share of its data; larger
    space dimensions give up half of --space-clip-percentage at the end.
    """
    if dimension == 0:  # time dimension
        return args.end_time
    # FIX: --small-treshold was parsed but never used; the old test was
    # `data.shape[dimension] <= 1`, which only spared singleton axes.
    # Use the configured threshold, as the option's help text describes.
    if data.shape[dimension] < args.small_treshold:
        return 100  # small space dimension: no clipping
    # TODO: different, use round-upwards division elsewhere
    return 100 - (args.space_clip_percentage // 2)

def make_borders(start_percentage, end_percentage):
    """Build a tuple of slices trimming each axis of the global `data`.

    start_percentage / end_percentage are callables mapping an axis index
    to the percentage (0-100) at which the kept region begins / ends.
    """
    slices = []
    for axis, extent in enumerate(data.shape):
        lo = apply_percentage(start_percentage(axis), extent)
        hi = apply_percentage(end_percentage(axis), extent)
        slices.append(slice(lo, hi))
    return tuple(slices)

def remove_all_borders(variable):
    """Trim the configured time and space borders off one variable array."""
    return variable[make_borders(start_percentage, end_percentage)]

print('Preprocessing input data ...')
vars = dict()
# Remove borders from each variable.
for key in data.vars.keys():
    vars[key] = remove_all_borders(data.vars[key])
# Use the phase space, not the configuration space: flatten each variable
# so every remaining grid point becomes one sample.
original_number_of_points = None
for key in data.vars.keys():
    vars[key] = vars[key].ravel()
    original_number_of_points = vars[key].shape[0]  # same length for every variable

# Reduce the number of points, for computational reasons.
np.random.seed(1)  # determinism
# FIX: sampling without replacement requires the sample size not to exceed
# the population; clamp so small simulations don't crash np.random.choice.
n_samples = min(args.points, original_number_of_points)
choices = np.random.choice(original_number_of_points, n_samples, replace=False)
for key in data.vars.keys():
    vars[key] = vars[key][choices]
del choices
gc.collect()  # save some memory

# Normalise data, if requested. Means and spreads are recorded for every
# variable regardless, so they can be saved alongside the diffusion map.
means = dict()
spreads = dict()
for key in data.vars.keys():
    mean = np.mean(vars[key])
    spread = np.std(vars[key])
    means[key] = mean
    spreads[key] = spread
    # K is constant for some of the models, don't divide by zero
    if args.normalize and spread != 0:
        vars[key] -= mean
        vars[key] /= spread
        gc.collect()

# Compute the diffusion map.
dmap = DiffusionMap.from_sklearn(epsilon=args.eps, k=args.k, n_evecs=len(vars.keys()))
# Stack the variables (in sorted-name order) as columns: one row per sample.
variable_names = sorted(vars.keys())
variable_values = np.column_stack(tuple(vars[name] for name in variable_names))
print(variable_values.shape)
# NOTE: deleting `vars` + gc.collect() here was tried to reduce peak memory
# usage, but did not seem very effective.
print('Computing the diffusion map ...')
dmap.fit(variable_values)
print('Computed!')

# ---------------------------------------------------------------------------
# Save the diffusion map.
# FIX: removed leftover debug output `print(dir(dmap))`.
# ---------------------------------------------------------------------------
# Collect the fitted quantities plus the preprocessing settings into a plain
# dict, so the result can be interpreted later.
dmap_info = {
    'variable_names': variable_names,
    'preprocessed_data': dmap.data,
    'means': means,
    'spreads': spreads,
    'alpha': dmap.alpha,
    'n_evecs': dmap.n_evecs,
    'evecs': dmap.evecs,
    'evals': dmap.evals,
    'right_norm_vec': dmap.right_norm_vec,
    'epsilon_fitted': dmap.epsilon_fitted,
    'weights': dmap.weights,
    'kernel_matrix': dmap.kernel_matrix,
    'L': dmap.L,
    'q': dmap.q,
    'dmap': dmap.dmap,
    # Preprocessing settings used to build the input point cloud.
    'normalize': args.normalize,
    'start_time': args.start_time,
    'end_time': args.end_time,
    'space_clip_percentage': args.space_clip_percentage,
    'local_kernel': dict(
        k=dmap.local_kernel.k,
        neigh=dmap.local_kernel.neigh,
        epsilon=dmap.local_kernel.epsilon,
        epsilon_fitted=dmap.local_kernel.epsilon_fitted,
    ),
}

with open(args.diffmap, 'wb') as f:
    pickle.dump(dmap_info, f)
